{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "collapsed": true,
    "ExecuteTime": {
     "end_time": "2023-05-09T10:21:52.789161200Z",
     "start_time": "2023-05-09T10:21:52.684159700Z"
    }
   },
   "outputs": [],
   "source": [
    "import sys\n",
    "\n",
    "sys.path.append('..')\n",
    "from common.time_layers import *\n",
    "from common.np import *\n",
    "from common.base_model import BaseModel"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "outputs": [],
   "source": [
    "class BetterRnnlm(BaseModel):\n",
    "    \"\"\"Language model with two stacked LSTM layers, dropout and weight tying.\n",
    "\n",
    "    Layer order: TimeEmbedding -> TimeDropout -> TimeLSTM -> TimeDropout ->\n",
    "    TimeLSTM -> TimeDropout -> TimeAffine, with TimeSoftmaxWithLoss as the loss.\n",
    "\n",
    "    Subclasses BaseModel because the training loop calls model.save_params(),\n",
    "    which this class does not define itself (presumably BaseModel supplies it\n",
    "    -- confirm against common.base_model).\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, vocab_size=10000, wordvec_size=650, hidden_size=650, dropout_ratio=0.5):\n",
    "        \"\"\"Create the weights and assemble the layer stack.\n",
    "\n",
    "        Args:\n",
    "            vocab_size: number of rows of the embedding matrix (V).\n",
    "            wordvec_size: number of columns of the embedding matrix (D).\n",
    "            hidden_size: H; every LSTM weight matrix has 4 * H columns.\n",
    "            dropout_ratio: value handed to each of the three TimeDropout layers.\n",
    "        \"\"\"\n",
    "        V, D, H = vocab_size, wordvec_size, hidden_size\n",
    "        rn = np.random.randn\n",
    "\n",
    "        embed_W = (rn(V, D) / 100).astype('f')\n",
    "        lstm_Wx1 = (rn(D, 4 * H) / np.sqrt(D)).astype('f')\n",
    "        lstm_Wh1 = (rn(H, 4 * H) / np.sqrt(H)).astype('f')\n",
    "        lstm_b1 = np.zeros(4 * H).astype('f')\n",
    "        lstm_Wx2 = (rn(H, 4 * H) / np.sqrt(H)).astype('f')\n",
    "        lstm_Wh2 = (rn(H, 4 * H) / np.sqrt(H)).astype('f')\n",
    "        lstm_b2 = np.zeros(4 * H).astype('f')\n",
    "        affine_b = np.zeros(V).astype('f')\n",
    "\n",
    "        # Weight tying: the output projection reuses the transposed embedding\n",
    "        # matrix, shape (D, V), instead of owning a separate weight.\n",
    "        # NOTE(review): the affine layer is fed the second LSTM's output, so this\n",
    "        # presumably needs wordvec_size == hidden_size -- confirm against TimeAffine.\n",
    "        self.layers = [\n",
    "            TimeEmbedding(embed_W),\n",
    "            TimeDropout(dropout_ratio),\n",
    "            TimeLSTM(lstm_Wx1, lstm_Wh1, lstm_b1, stateful=True),\n",
    "            TimeDropout(dropout_ratio),\n",
    "            TimeLSTM(lstm_Wx2, lstm_Wh2, lstm_b2, stateful=True),\n",
    "            TimeDropout(dropout_ratio),\n",
    "            TimeAffine(embed_W.T, affine_b)\n",
    "        ]\n",
    "\n",
    "        self.loss_layer = TimeSoftmaxWithLoss()\n",
    "        # Keep direct handles on the LSTM and dropout layers so their state and\n",
    "        # train flag can be changed without walking the whole stack.\n",
    "        self.lstm_layers = [self.layers[2], self.layers[4]]\n",
    "        self.drop_layers = [self.layers[1], self.layers[3],\n",
    "                            self.layers[5]]\n",
    "\n",
    "        # Flatten every layer's parameters and gradients into model-level lists.\n",
    "        self.params, self.grads = [], []\n",
    "        for layer in self.layers:\n",
    "            self.params += layer.params\n",
    "            self.grads += layer.grads\n",
    "\n",
    "    def predict(self, xs, train_flg=False):\n",
    "        \"\"\"Run xs through every layer and return the last layer's output.\n",
    "\n",
    "        train_flg is copied onto each dropout layer before the forward pass.\n",
    "        \"\"\"\n",
    "        for layer in self.drop_layers:\n",
    "            layer.train_flg = train_flg\n",
    "        for layer in self.layers:\n",
    "            xs = layer.forward(xs)\n",
    "        return xs\n",
    "\n",
    "    def forward(self, xs, ts, train_flg=True):\n",
    "        \"\"\"Return the loss of the loss layer for inputs xs against targets ts.\"\"\"\n",
    "        score = self.predict(xs, train_flg)\n",
    "        loss = self.loss_layer.forward(score, ts)\n",
    "        return loss\n",
    "\n",
    "    def backward(self, dout=1):\n",
    "        \"\"\"Backpropagate dout through the loss layer, then the layers in reverse order.\n",
    "\n",
    "        Returns whatever the first layer's backward() returns.\n",
    "        \"\"\"\n",
    "        dout = self.loss_layer.backward(dout)\n",
    "        for layer in reversed(self.layers):\n",
    "            dout = layer.backward(dout)\n",
    "        return dout\n",
    "\n",
    "    def reset_state(self):\n",
    "        \"\"\"Call reset_state() on both LSTM layers.\"\"\"\n",
    "        for layer in self.lstm_layers:\n",
    "            layer.reset_state()"
   ],
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2023-05-09T10:32:04.478897100Z",
     "start_time": "2023-05-09T10:32:04.469897900Z"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "outputs": [],
   "source": [
    "from common import config\n",
    "from common.optimizer import SGD\n",
    "from common.trainer import RnnlmTrainer\n",
    "from common.util import eval_perplexity\n",
    "from dataset import ptb"
   ],
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2023-05-09T10:29:23.571995900Z",
     "start_time": "2023-05-09T10:29:23.037993400Z"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "outputs": [],
   "source": [
    "# Hyperparameters for this run, grouped by where they are consumed.\n",
    "\n",
    "# Model size and regularization (arguments to BetterRnnlm)\n",
    "wordvec_size = 650\n",
    "hidden_size = 650\n",
    "dropout = 0.5\n",
    "\n",
    "# Optimizer: lr is the initial value given to SGD; the training loop divides\n",
    "# it by 4 whenever validation perplexity fails to improve.\n",
    "lr = 20.0\n",
    "max_grad = 0.25\n",
    "\n",
    "# Batching and schedule (arguments to trainer.fit / the epoch loop)\n",
    "batch_size = 20\n",
    "time_size = 35\n",
    "max_epoch = 40"
   ],
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2023-05-09T10:29:33.016344900Z",
     "start_time": "2023-05-09T10:29:33.007346900Z"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "outputs": [],
   "source": [
    "# 读入训练数据\n",
    "corpus, word_to_id, id_to_word = ptb.load_data('train')\n",
    "corpus_val, _, _ = ptb.load_data('val')\n",
    "corpus_test, _, _ = ptb.load_data('test')"
   ],
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2023-05-09T10:29:37.927148200Z",
     "start_time": "2023-05-09T10:29:37.888149900Z"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "outputs": [],
   "source": [
    "# Vocabulary size is the number of entries in the word -> id mapping.\n",
    "vocab_size = len(word_to_id)\n",
    "\n",
    "# Inputs are every token except the last; targets are the same sequence\n",
    "# shifted one position ahead, so ts[i] is the token that follows xs[i].\n",
    "xs, ts = corpus[:-1], corpus[1:]"
   ],
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2023-05-09T10:29:43.756331800Z",
     "start_time": "2023-05-09T10:29:43.747331100Z"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "outputs": [],
   "source": [
    "# Build the model, its optimizer, and the trainer that couples them.\n",
    "model = BetterRnnlm(\n",
    "    vocab_size=vocab_size,\n",
    "    wordvec_size=wordvec_size,\n",
    "    hidden_size=hidden_size,\n",
    "    dropout_ratio=dropout,\n",
    ")\n",
    "optimizer = SGD(lr)\n",
    "trainer = RnnlmTrainer(model, optimizer)"
   ],
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2023-05-09T10:32:08.332695100Z",
     "start_time": "2023-05-09T10:32:08.007697Z"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "best_ppl = float('inf')\n",
    "for epoch in range(max_epoch):\n",
    "    # Train one epoch at a time so validation can run between epochs.\n",
    "    trainer.fit(xs, ts, max_epoch=1, batch_size=batch_size, time_size=time_size, max_grad=max_grad)\n",
    "\n",
    "    # Reset the LSTM layers before scoring the validation split.\n",
    "    model.reset_state()\n",
    "    ppl = eval_perplexity(model, corpus_val)\n",
    "    print('valid perplexity: ', ppl)\n",
    "\n",
    "    if best_ppl > ppl:\n",
    "        # New best validation perplexity: remember it and save the parameters.\n",
    "        best_ppl = ppl\n",
    "        model.save_params()\n",
    "    else:\n",
    "        # No improvement: shrink the learning rate by a factor of 4.\n",
    "        lr /= 4.0\n",
    "        optimizer.lr = lr\n",
    "\n",
    "    # Reset again so the next epoch does not start from evaluation state.\n",
    "    model.reset_state()\n",
    "    print('-' * 50)\n",
    "\n",
    "# Final evaluation on the test split, which is loaded earlier but was never\n",
    "# scored even though the notebook's closing notes report a test perplexity.\n",
    "model.reset_state()\n",
    "ppl_test = eval_perplexity(model, corpus_test)\n",
    "print('test perplexity: ', ppl_test)"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "\"\"\"\n",
    "这个学习需要相当长的时间。在用 CPU 运行的情况下，需要 2 天左右；而如果用 GPU 运行，则能在 5 小时左右完成\n",
    "\"\"\""
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "markdown",
   "source": [
    "执行上面的代码，困惑度平稳下降，最终在测试数据上获得了困惑度为75.76 的结果（每次运行结果不同）。考虑到改进前的 RNNLM 的困惑度约为 136，这个结果可以说提升很大。通过 LSTM 的多层化提高表现力，通过 Dropout 提高泛化能力，通过权重共享有效利用权重，从而实现了精度的大幅提高。"
   ],
   "metadata": {
    "collapsed": false
   }
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}
